"""
inductive biases on data. 数据对解耦表征有着关键的影响。
我们认为数据的显著度影响了解耦。range参数指示了rotation、translation的范围。
range应该与threshold有关。KL与(H^(2/3)/Beta^(2/3))相关，β越大KL越小。
"""
import argparse
from collections import defaultdict

from torch import optim
from torch.utils.data import DataLoader
import wandb
from tqdm import trange

from disvae.models.anneal import get_anneal
from disvae.models.losses import get_loss_s
from disvae.models.vae import VAE
from disvae.models.encoders import get_encoder
from disvae.models.decoders import get_decoder
from exps.visualize import plt_sample_traversal
from utils.visualize import *
from exps.patterns import data_generator
import numpy as np

from utils.helpers import set_seed

# Hyperparameter defaults; every entry doubles as a CLI flag of the same name.
hyperparameter_defaults = {
    'batch_size': 256,
    'learning_rate': 1e-3,
    'epochs': 1001,
    'dimension': 1,
    'loss': 'betaH',
    'beta': 16.0,
    'scope_x': 10.0,
    'angle': 90.0,
    'img_id': 0,
    'random_seed': 154,
}

parser = argparse.ArgumentParser()
for name, default in hyperparameter_defaults.items():
    # Infer each flag's parser type from its default so overrides keep the type.
    parser.add_argument(f'--{name}', type=type(default), default=default)
args = parser.parse_args()
config = args

set_seed(config.random_seed)
wandb.init(project="exps", config=config, group='bias', tags=['beta_vs_KL'])


def get_preds(model, dl):
    """Run the encoder over every batch in `dl` and collect its outputs.

    The model is switched to eval mode for the pass and restored to train
    mode before returning.

    Returns:
        (preds, targets): concatenated encoder means and the matching labels.
    """
    model.eval()
    all_mu, all_labels = [], []
    with torch.no_grad():
        for batch, labels in dl:
            batch = batch.view(-1, 1, 64, 64).cuda()
            mu, _ = model.encoder(batch)
            all_mu.append(mu)
            all_labels.append(labels)
    model.train()
    return torch.cat(all_mu), torch.cat(all_labels)


def train(dl, model, loss_f, epochs):
    """Optimize `model` on `dl` for `epochs` full passes, logging to wandb.

    `loss_f` appends its per-batch terms into a storer dict; those lists are
    reduced to epoch means before logging. Every 100 epochs one input /
    reconstruction image pair is logged as well.

    Returns:
        The storer dict of the final epoch (mean loss terms plus 'epoch').
    """
    opt = optim.AdamW(model.parameters(), config.learning_rate)
    for epoch in trange(epochs):
        storer = defaultdict(list)
        epoch_losses = []
        for img, label in dl:
            img = img.cuda()
            recon_batch, latent_dist, latent_sample = model(img)
            loss = loss_f(img, recon_batch, latent_dist, model.training,
                          storer, latent_sample=latent_sample)
            opt.zero_grad()
            loss.backward()
            opt.step()
            epoch_losses.append(loss.item())

        # Reduce the per-batch lists gathered by loss_f into scalar means.
        storer = {k: np.mean(v) if isinstance(v, list) else v
                  for k, v in storer.items()}
        storer['loss'] = np.mean(epoch_losses)

        if epoch % 100 == 0:
            # Log the last batch's first reconstruction next to its input.
            wandb.log({
                'recon': wandb.Image(recon_batch[0, 0]),
                'img': wandb.Image(img[0, 0])
            })
            plt.close()
        storer['epoch'] = epoch
        wandb.log(storer, sync=False)
    return storer


# ---- data generation & model setup ----
img_id = config.img_id
epochs = config.epochs
dim = config.dimension
beta = config.beta

# Load one base pattern, rotate it, then build a horizontal-translation dataset.
img = torch.load(f'patterns/{img_id}.pat')
img = data_generator.rotate(img, config.angle)
dataset = data_generator.TranslationX(img, config.scope_x)

dl = DataLoader(dataset, num_workers=0, batch_size=config.batch_size,
                shuffle=True)

vae = VAE((1, 64, 64), get_encoder('Burgess'), get_decoder('Burgess'), dim)
vae.cuda()
vae.train()

# wandb.watch(vae, log_freq=10)
iterations = len(dl) * epochs
# anneal = get_anneal('monotonic', iterations, 1, 0)
anneal = get_anneal('constant', iterations, 1, 1)
loss_f = get_loss_s(config.loss, anneal, config.beta, len(dl.dataset))

storer = train(dl, vae, loss_f, epochs)

# Rank latent dimensions by their final KL term, largest first.
kl_loss = torch.Tensor([v for k, v in storer.items() if k.startswith('kl_')])
_, index = kl_loss.sort(descending=True)

vae.eval()
vae.cpu()


def inversion(seq):
    """Count inverted vs. in-order position pairs for each row of `seq`.

    For every row, each pair of positions (j, i) with j < i is classified:
    seq[:, j] > seq[:, i] counts as an inversion, seq[:, j] <= seq[:, i]
    as in-order. A perfectly monotone row (in either direction) therefore
    yields a minimum count of 0.

    Args:
        seq: tensor of shape (rows, n); here one row per latent dimension.

    Returns:
        Tuple (min, max) of two long tensors of shape (rows,): the
        elementwise smaller and larger of (inversions, in-order pairs).
    """
    n = seq.size(1)
    # Vectorized pairwise comparison replaces the original O(n^2) Python
    # double loop: cmp[b, j, i] is True when seq[b, j] > seq[b, i].
    cmp = seq.unsqueeze(2) > seq.unsqueeze(1)
    # Strict upper triangle keeps exactly the pairs with j < i.
    pair_mask = torch.triu(
        torch.ones(n, n, dtype=torch.bool, device=seq.device), diagonal=1)
    inversions = (cmp & pair_mask).sum(dim=(1, 2))
    in_order = (~cmp & pair_mask).sum(dim=(1, 2))
    # Always stacking tensors also fixes the n == 1 edge case, where the old
    # loop left plain ints and torch.stack([0, 0]) raised a TypeError.
    t = torch.stack([inversions, in_order])
    return t.min(0)[0], t.max(0)[0]


# ---- evaluation: latent ordering + traversal figure ----
mu = vae.encoder(dataset.imgs)[0]
s_min, s_max = inversion(mu.transpose(0, 1))

fig = plt_sample_traversal(dataset.imgs[:1], vae, 7, index[:3], r=2)
wandb.log({'traversal': wandb.Image(fig)})

summary_stats = {
    's_min': s_min.float().mean().item(),
    's_max': s_max.float().mean().item(),
    'inv': (s_min.float() / s_max.float()).mean().item(),
}
for key, value in summary_stats.items():
    wandb.summary.update({key: value})

plt.show()
